{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# =============================================================================\n",
    "# Copyright (c) 2020 NVIDIA. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================\n",
    "\n",
    "from torch import optim\n",
    "\n",
    "from nemo.core import DeviceType, NeuralGraph, NeuralModuleFactory, OperationMode\n",
    "from nemo.utils import logging\n",
    "\n",
    "from nemo.collections.cv.modules.data_layers import STL10DataLayer\n",
    "from nemo.collections.cv.modules.losses import NLLLoss\n",
    "from nemo.collections.cv.modules.non_trainables import NonLinearity, ReshapeTensor\n",
    "from nemo.collections.cv.modules.trainables import ImageEncoder, FeedForwardNetwork\n",
    "\n",
    "# NOTE: the CPU is selected by default so that this notebook can be executed on\n",
    "# every machine. Training on the CPU will be extremely slow, though, so switching\n",
    "# the device to DeviceType.GPU is strongly suggested.\n",
    "device = DeviceType.CPU\n",
    "\n",
    "# Build the Neural(Module)Factory, placed on the device selected above.\n",
    "nf = NeuralModuleFactory(\n",
    "    placement=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tutorial III: Custom Training\n",
    "\n",
    "In this third part of the Neural Graphs (NGs) tutorial we will focus on a different example: training of an image classification model with a ResNet-50 backbone on the STL-10 dataset using a custom training loop.\n",
    "\n",
    "#### This part covers the following:\n",
    "\n",
    " * how to create separate graphs for training and evaluation\n",
    " * how to move a graph between CPU/GPU devices\n",
    " * how to parametrize data loaders\n",
    " * how to write a custom training loop relying on graph actions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data layers: one for the \"train\" split and one for the \"test\" split (used for validation),\n",
    "# both configured with 224x224 height/width.\n",
    "dl_train = STL10DataLayer(\n",
    "    height=224,\n",
    "    width=224,\n",
    "    split=\"train\",\n",
    ")\n",
    "dl_valid = STL10DataLayer(\n",
    "    height=224,\n",
    "    width=224,\n",
    "    split=\"test\",\n",
    ")\n",
    "\n",
    "# Loss module, shared later by the training and validation graphs.\n",
    "nll_loss = NLLLoss()\n",
    "\n",
    "# Be patient: the dataset has to be downloaded and verified, which may take a while..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Building blocks of the \"model\" - a pretrained ResNet-50 serves as the image encoder.\n",
    "image_encoder = ImageEncoder(\n",
    "    model_type=\"resnet50\",\n",
    "    pretrained=True,\n",
    "    return_feature_maps=True,\n",
    ")\n",
    "# Flatten the feature maps: 7 * 7 * 2048 = 100352 values per sample.\n",
    "reshaper = ReshapeTensor(\n",
    "    input_sizes=[-1, 7, 7, 2048],\n",
    "    output_sizes=[-1, 100352],\n",
    ")\n",
    "# Classifier head: two hidden layers of 100 units each, 10 outputs.\n",
    "ffn = FeedForwardNetwork(\n",
    "    input_size=100352,\n",
    "    output_size=10,\n",
    "    hidden_sizes=[100, 100],\n",
    "    dropout_rate=0.1,\n",
    ")\n",
    "# Log-softmax over the 10 outputs.\n",
    "nl = NonLinearity(\n",
    "    type=\"logsoftmax\",\n",
    "    sizes=[-1, 10],\n",
    ")\n",
    "\n",
    "# The encoder is frozen to make the training faster.\n",
    "image_encoder.freeze()\n",
    "\n",
    "# This step might also take some time, as the pretrained checkpoint has to be downloaded..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the \"model graph\" (created with OperationMode.both).\n",
    "with NeuralGraph(operation_mode=OperationMode.both) as stl10_resnet_classifier:\n",
    "    # The encoder input is bound to the graph itself.\n",
    "    feature_maps = image_encoder(inputs=stl10_resnet_classifier)\n",
    "    flat_features = reshaper(inputs=feature_maps)\n",
    "    class_logits = ffn(inputs=flat_features)\n",
    "    log_probs = nl(inputs=class_logits)\n",
    "    # Expose only the final result, as the graph output named \"predictions\".\n",
    "    stl10_resnet_classifier.outputs[\"predictions\"] = log_probs\n",
    "\n",
    "# Log the summary of the graph that was just built.\n",
    "logging.info(stl10_resnet_classifier.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compose the training graph on top of the model graph...\n",
    "with NeuralGraph(operation_mode=OperationMode.training) as training_graph:\n",
    "    # The data layer returns four outputs - only the second and third are used here.\n",
    "    _, images, targets, _ = dl_train()\n",
    "    # Run the images through the model graph.\n",
    "    predictions = stl10_resnet_classifier(inputs=images)\n",
    "    # Compute the loss from the predictions and the targets.\n",
    "    loss = nll_loss(predictions=predictions, targets=targets)\n",
    "\n",
    "# Log the summary of the training graph.\n",
    "logging.info(training_graph.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ... and, analogously, a validation graph fed by the validation data layer.\n",
    "with NeuralGraph(operation_mode=OperationMode.evaluation) as validation_graph:\n",
    "    # The data layer returns four outputs - only the second and third are used here.\n",
    "    _, images_valid, targets_valid, _ = dl_valid()\n",
    "    # Run the images through the very same model graph.\n",
    "    predictions_valid = stl10_resnet_classifier(inputs=images_valid)\n",
    "    # Compute the loss with the same loss module as in the training graph.\n",
    "    valid_loss = nll_loss(predictions=predictions_valid, targets=targets_valid)\n",
    "\n",
    "# Log the summary of the validation graph.\n",
    "logging.info(validation_graph.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move both graphs to the selected device.\n",
    "training_graph.to(device)\n",
    "validation_graph.to(device)\n",
    "\n",
    "# Adam optimizer over the parameters exposed by the training graph.\n",
    "opt = optim.Adam(\n",
    "    training_graph.parameters(),\n",
    "    lr=0.001,\n",
    ")\n",
    "\n",
    "# The training loss is logged every `freq` steps.\n",
    "freq = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finally, construct and run the custom training loop.\n",
    "# Every epoch consists of a full pass over the training set (with parameter updates)\n",
    "# followed by a full pass over the validation set (loss reporting only).\n",
    "\n",
    "# Train for 5 epochs.\n",
    "for epoch in range(5):\n",
    "    # Configure data loader used by the training graph - once per epoch.\n",
    "    # Use default settings - just change the batch_size and turn sample shuffling on.\n",
    "    training_graph.configure_data_loader(batch_size=64, shuffle=True)\n",
    "\n",
    "    # Iterate over the whole dataset - in batches.\n",
    "    for step, batch in enumerate(training_graph.get_batch()):\n",
    "\n",
    "        # Reset the gradients.\n",
    "        opt.zero_grad()\n",
    "\n",
    "        # Forward pass.\n",
    "        outputs = training_graph.forward(batch)\n",
    "        # Print loss every `freq` steps - as a plain Python number.\n",
    "        if step % freq == 0:\n",
    "            logging.info(\"Epoch: {} Step: {} Training Loss: {}\".format(epoch, step, outputs.loss.item()))\n",
    "\n",
    "        # Backpropagate the gradients.\n",
    "        training_graph.backward()\n",
    "\n",
    "        # Update the parameters.\n",
    "        opt.step()\n",
    "    # Epoch ended.\n",
    "\n",
    "    # Evaluate graph on test set.\n",
    "    # Configure data loader used by the validation graph - once per epoch.\n",
    "    valid_losses = []\n",
    "    validation_graph.configure_data_loader(batch_size=64)\n",
    "    # Iterate over the whole dataset - in batches.\n",
    "    for batch in validation_graph.get_batch():\n",
    "        # Forward pass.\n",
    "        outputs = validation_graph.forward(batch)\n",
    "        # Store a plain float instead of the loss tensor: a stored tensor could keep its\n",
    "        # whole autograd graph alive for every batch of the validation pass. Whether\n",
    "        # forward() already disables gradient tracking here is not visible in this notebook.\n",
    "        valid_losses.append(outputs.loss.item())\n",
    "    # Print average loss.\n",
    "    logging.info(\"Epoch: {} Avg. Validation Loss: {}\".format(epoch, sum(valid_losses) / len(valid_losses)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nemo-env",
   "language": "python",
   "name": "nemo-env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
